{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[download this notebook here](https://github.com/HumanCompatibleAI/imitation/blob/master/docs/tutorials/7_train_density.ipynb)\n",
    "# Learning a Reward Function using Kernel Density\n",
    "\n",
    "This demo shows how to train a `Pendulum` agent (exciting!) with our simple density-based imitation learning baselines. `DensityTrainer` has a few interesting parameters, but the key ones are:\n",
    "\n",
    "1. `density_type`: this governs whether density is measured on $(s,s')$ pairs (`db.STATE_STATE_DENSITY`), $(s,a)$ pairs (`db.STATE_ACTION_DENSITY`), or single states (`db.STATE_DENSITY`).\n",
    "2. `is_stationary`: determines whether a separate density model is used for each time step $t$ (`False`), or the same model is used for transitions at all times (`True`).\n",
    "3. `standardise_inputs`: if `True`, each dimension of the agent state vectors will be normalised to have zero mean and unit variance over the training dataset. This can be useful when not all elements of the demonstration vector are on the same scale, or when some elements have too wide a variation to be captured by the fixed kernel width (1 for Gaussian kernel).\n",
    "4. `kernel`: changes the kernel used for non-parametric density estimation. `gaussian` and `exponential` are the best bets; see the [sklearn docs](https://scikit-learn.org/stable/modules/density.html#kernel-density) for the rest."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pprint\n",
    "\n",
    "from imitation.algorithms import density as db\n",
    "from imitation.data import types\n",
    "from imitation.util import util"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set FAST = False for longer training. Use True for testing and CI.\n",
    "FAST = True\n",
    "\n",
    "if FAST:\n",
    "    N_VEC = 1\n",
    "    N_TRAJECTORIES = 1\n",
    "    N_ITERATIONS = 1\n",
    "    N_RL_TRAIN_STEPS = 100\n",
    "\n",
    "else:\n",
    "    N_VEC = 8\n",
    "    N_TRAJECTORIES = 10\n",
    "    N_ITERATIONS = 10\n",
    "    N_RL_TRAIN_STEPS = 100_000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from imitation.policies.serialize import load_policy\n",
    "from stable_baselines3.common.policies import ActorCriticPolicy\n",
    "from stable_baselines3 import PPO\n",
    "from imitation.data import rollout\n",
    "from stable_baselines3.common.vec_env import DummyVecEnv\n",
    "from stable_baselines3.common.evaluation import evaluate_policy\n",
    "from imitation.data.wrappers import RolloutInfoWrapper\n",
    "import gymnasium as gym\n",
    "import numpy as np\n",
    "\n",
    "SEED = 42\n",
    "\n",
    "rng = np.random.default_rng(seed=SEED)\n",
    "env_name = \"Pendulum-v1\"\n",
    "rollout_env = DummyVecEnv(\n",
    "    [lambda: RolloutInfoWrapper(gym.make(env_name)) for _ in range(N_VEC)]\n",
    ")\n",
    "expert = load_policy(\n",
    "    \"ppo-huggingface\",\n",
    "    organization=\"HumanCompatibleAI\",\n",
    "    env_name=env_name,\n",
    "    venv=rollout_env,\n",
    ")\n",
    "rollouts = rollout.rollout(\n",
    "    expert,\n",
    "    rollout_env,\n",
    "    rollout.make_sample_until(min_timesteps=2000, min_episodes=57),\n",
    "    rng=rng,\n",
    ")\n",
    "\n",
    "env = util.make_vec_env(env_name, n_envs=N_VEC, rng=rng)\n",
    "\n",
    "\n",
    "imitation_trainer = PPO(\n",
    "    ActorCriticPolicy, env, learning_rate=3e-4, gamma=0.95, ent_coef=1e-4, n_steps=2048\n",
    ")\n",
    "density_trainer = db.DensityAlgorithm(\n",
    "    venv=env,\n",
    "    rng=rng,\n",
    "    demonstrations=rollouts,\n",
    "    rl_algo=imitation_trainer,\n",
    "    density_type=db.DensityType.STATE_ACTION_DENSITY,\n",
    "    is_stationary=True,\n",
    "    kernel=\"gaussian\",\n",
    "    kernel_bandwidth=0.4,  # found using divination & some palm reading\n",
    "    standardise_inputs=True,\n",
    ")\n",
    "density_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# evaluate the expert\n",
    "expert_rewards, _ = evaluate_policy(expert, env, 100, return_episode_rewards=True)\n",
    "\n",
    "# evaluate the learner before training\n",
    "learner_rewards_before_training, _ = evaluate_policy(\n",
    "    density_trainer.policy, env, 100, return_episode_rewards=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_stats(density_trainer, n_trajectories, epoch=\"\"):\n",
    "    stats = density_trainer.test_policy(n_trajectories=n_trajectories)\n",
    "    print(\"True reward function stats:\")\n",
    "    pprint.pprint(stats)\n",
    "    stats_im = density_trainer.test_policy(\n",
    "        true_reward=False,\n",
    "        n_trajectories=n_trajectories,\n",
    "    )\n",
    "    print(f\"Imitation reward function stats, epoch {epoch}:\")\n",
    "    pprint.pprint(stats_im)\n",
    "\n",
    "\n",
    "novice_stats = density_trainer.test_policy(n_trajectories=N_TRAJECTORIES)\n",
    "print(\"Stats before training:\")\n",
    "print_stats(density_trainer, 1)\n",
    "\n",
    "print(\"Starting the training!\")\n",
    "for i in range(N_ITERATIONS):\n",
    "    density_trainer.train_policy(N_RL_TRAIN_STEPS)\n",
    "    print_stats(density_trainer, 1, epoch=str(i))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# evaluate the learner after training\n",
    "learner_rewards_after_training, _ = evaluate_policy(\n",
    "    density_trainer.policy, env, 100, return_episode_rewards=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here are the final results. If you set `FAST = False` in one of the initial cells, you should see that performance after training approaches that of an expert."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Mean expert reward:\", np.mean(expert_rewards))\n",
    "print(\"Mean reward before training:\", np.mean(learner_rewards_before_training))\n",
    "print(\"Mean reward after training:\", np.mean(learner_rewards_after_training))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.12 ('imitation')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
